/**
 * Copyright (c) 2021-2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "unit_ecma_test.h"
#include "optimizer/ir/datatype.h"
#include "optimizer/ir/graph_cloner.h"
#include "optimizer/optimizations/cleanup.h"
#include "optimizer/optimizations/vn.h"

namespace ark::compiler {
/**
 * Fixture for the value-numbering (ValNum) tests in this file.
 * It declares nothing of its own: the graph factories used by the tests
 * (e.g. CreateGraphDynWithDefaultRuntime) are inherited from AsmTest or its bases,
 * and the implicitly declared default constructor is sufficient (Rule of Zero).
 */
class VNTest : public AsmTest {};

// Positive case: two CompareAnyType instructions with the same AnyType (ECMASCRIPT_INT_TYPE)
// and the same input (parameter 0) must be value-numbered together. In the expected graph
// only INST 2 survives, and the Return in block 4 reads INST 2 directly.
TEST_F(VNTest, CompareAnyTypeVNTrue)
{
    auto graph = CreateGraphDynWithDefaultRuntime();
    GRAPH(graph)
    {
        PARAMETER(0, 0).any();

        BASIC_BLOCK(2, 3, 4)
        {
            // First type check of parameter 0; its result also drives the branch below.
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            CONSTANT(4, 42);
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, -1)
        {
            // Duplicate of INST 2 (same opcode, AnyType and input). Block 2, which holds INST 2,
            // is the only block that branches here, so this is the redundant computation VN targets.
            INST(6, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(7, Opcode::Return).s32().Inputs(6);
        }
    }

    // Both passes must report a change: ValNum for redirecting the users of INST 6 to INST 2,
    // and Cleanup presumably for dropping the now-unused INST 6.
    ASSERT_TRUE(graph->RunPass<ValNum>());
    ASSERT_TRUE(graph->RunPass<Cleanup>());
    GraphChecker(graph).Check();

    // Expected result: the input graph without INST 6, with INST 7 returning INST 2.
    auto graphOpt = CreateGraphDynWithDefaultRuntime();
    GRAPH(graphOpt)
    {
        PARAMETER(0, 0).any();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            CONSTANT(4, 42);
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, -1)
        {
            INST(7, Opcode::Return).s32().Inputs(2);
        }
    }

    EXPECT_TRUE(GraphComparator().Compare(graph, graphOpt));
}

// Negative case: the two CompareAnyType instructions read the same input but check different
// AnyTypes (ECMASCRIPT_INT_TYPE vs ECMASCRIPT_BOOLEAN_TYPE). They are therefore not equivalent,
// and neither ValNum nor Cleanup may change the graph.
TEST_F(VNTest, CompareAnyTypeVNFalse)
{
    auto graph = CreateGraphDynWithDefaultRuntime();
    GRAPH(graph)
    {
        PARAMETER(0, 0).any();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            CONSTANT(4, 42);
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, -1)
        {
            // Same input as INST 2 but a different AnyType, so it must NOT be merged with INST 2.
            INST(6, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(0);
            INST(7, Opcode::Return).s32().Inputs(6);
        }
    }

    // Snapshot of the input graph, taken before any pass runs; the graph must still match it afterwards.
    auto graphOpt = GraphCloner(graph, graph->GetAllocator(), graph->GetLocalAllocator()).CloneGraph();

    // Neither pass may report a change.
    ASSERT_FALSE(graph->RunPass<ValNum>());
    ASSERT_FALSE(graph->RunPass<Cleanup>());
    GraphChecker(graph).Check();

    EXPECT_TRUE(GraphComparator().Compare(graph, graphOpt));
}

// Positive case: two CastAnyTypeValue instructions with the same result type (bool), the same
// AnyType (ECMASCRIPT_BOOLEAN_TYPE) and the same input (parameter 0) must be value-numbered together.
// In the expected graph only INST 2 survives, and the Return in block 4 reads INST 2 directly.
TEST_F(VNTest, CastAnyTypeVNTrue)
{
    auto graph = CreateGraphDynWithDefaultRuntime();
    GRAPH(graph)
    {
        PARAMETER(0, 0).any();

        BASIC_BLOCK(2, 3, 4)
        {
            // First cast of parameter 0; its result also drives the branch below.
            INST(2, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            CONSTANT(4, 42);
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, -1)
        {
            // Duplicate of INST 2 (same opcode, result type, AnyType and input). Block 2, which
            // holds INST 2, is the only block that branches here.
            INST(6, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(0);
            INST(7, Opcode::Return).s32().Inputs(6);
        }
    }

    // Both passes must report a change: ValNum for redirecting the users of INST 6 to INST 2,
    // and Cleanup presumably for dropping the now-unused INST 6.
    ASSERT_TRUE(graph->RunPass<ValNum>());
    ASSERT_TRUE(graph->RunPass<Cleanup>());
    GraphChecker(graph).Check();

    // Expected result: the input graph without INST 6, with INST 7 returning INST 2.
    auto graphOpt = CreateGraphDynWithDefaultRuntime();
    GRAPH(graphOpt)
    {
        PARAMETER(0, 0).any();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(2, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            CONSTANT(4, 42);
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, -1)
        {
            INST(7, Opcode::Return).s32().Inputs(2);
        }
    }

    EXPECT_TRUE(GraphComparator().Compare(graph, graphOpt));
}

// Negative case: the two CastAnyTypeValue instructions read the same input but differ in both
// result type (bool vs s32) and AnyType (ECMASCRIPT_BOOLEAN_TYPE vs ECMASCRIPT_INT_TYPE). They
// are therefore not equivalent, and neither ValNum nor Cleanup may change the graph.
TEST_F(VNTest, CastAnyTypeVNFalse)
{
    auto graph = CreateGraphDynWithDefaultRuntime();
    GRAPH(graph)
    {
        PARAMETER(0, 0).any();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(2, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            CONSTANT(4, 42);
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, -1)
        {
            // Same input as INST 2 but a different result type and AnyType, so it must NOT be merged.
            INST(6, Opcode::CastAnyTypeValue).s32().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(7, Opcode::Return).s32().Inputs(6);
        }
    }

    // Snapshot of the input graph, taken before any pass runs; the graph must still match it afterwards.
    auto graphOpt = GraphCloner(graph, graph->GetAllocator(), graph->GetLocalAllocator()).CloneGraph();

    // Neither pass may report a change.
    ASSERT_FALSE(graph->RunPass<ValNum>());
    ASSERT_FALSE(graph->RunPass<Cleanup>());
    GraphChecker(graph).Check();

    EXPECT_TRUE(GraphComparator().Compare(graph, graphOpt));
}

// Mixed case within a single block: INST 1 and INST 3 are identical casts of constant 0 to
// ECMASCRIPT_INT_TYPE, so INST 3 must be value-numbered onto INST 1. INST 2 casts the same
// constant to ECMASCRIPT_BOOLEAN_TYPE, so it is a distinct value and must survive.
TEST_F(VNTest, CastValueToAnyType)
{
    auto graph = CreateGraphDynWithDefaultRuntime();
    GRAPH(graph)
    {
        CONSTANT(0, 0);
        BASIC_BLOCK(2, -1)
        {
            INST(1, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            // Different AnyType than INST 1: not equivalent, must be kept.
            INST(2, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(0);
            // Exact duplicate of INST 1: expected to be replaced by INST 1.
            INST(3, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(4, Opcode::SaveState).Inputs(0).SrcVregs({0});
            // The call consumes all three casts, so every one of them has a user before the passes run.
            INST(5, Opcode::CallStatic).s32().InputsAutoType(1, 2, 3, 4);
            INST(6, Opcode::Return).s32().Inputs(5);
        }
    }

    // Both passes must report a change: ValNum for redirecting the use of INST 3 to INST 1,
    // and Cleanup presumably for dropping the now-unused INST 3.
    ASSERT_TRUE(graph->RunPass<ValNum>());
    ASSERT_TRUE(graph->RunPass<Cleanup>());
    GraphChecker(graph).Check();

    // Expected result: INST 3 is gone and the call takes INST 1 in its place.
    auto graphOpt = CreateGraphDynWithDefaultRuntime();
    GRAPH(graphOpt)
    {
        CONSTANT(0, 0);
        BASIC_BLOCK(2, -1)
        {
            INST(1, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(2, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(0);
            INST(4, Opcode::SaveState).Inputs(0).SrcVregs({0});
            INST(5, Opcode::CallStatic).s32().InputsAutoType(1, 2, 1, 4);
            INST(6, Opcode::Return).s32().Inputs(5);
        }
    }

    EXPECT_TRUE(GraphComparator().Compare(graph, graphOpt));
}

}  // namespace ark::compiler
